{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "jacobian_determinant.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [
        "Zeu-2daF2lPh"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zeu-2daF2lPh",
        "colab_type": "text"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iLohXFLL2pm9",
        "colab_type": "code",
        "colab": {},
        "cellView": "form"
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "E4xiYigexMtQ"
      },
      "source": [
        "# Jacobian determinant for chanage of variables\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dEFXJ31C2QEW",
        "colab_type": "text"
      },
      "source": [
        "_Notebook orignially contributed by: [MarkDaoust](https://github.com/markdaoust)_\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github.com/tensorflow/examples/blob/master/community/en/jacobian_determinant.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/tree/master/community/en/jacobian_determinant.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "N40uBBnmq6dd"
      },
      "source": [
        "This notebook uses a jacobian determinant to do a change of variables in a probabilty distribution, and make some nice visualizations.\n",
        "\n",
        "This is just a quick walkthrough, you should probably use [TensorFlow probability](https://tensorflow.org/Probability) for any serious applications. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "f5iLo1iYrhBA"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IqR2PQG4ZaZ0",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "mpl.rcParams['figure.figsize'] = (12, 8)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Hx1iAPSdkheq"
      },
      "source": [
        "##  Create the distribution\n",
        "Build a `(x0,x1)` grid."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "zUSdUOWwEy8y",
        "colab": {}
      },
      "source": [
        "x0 = tf.linspace(-12.0, 12.0, 240+1)[:, tf.newaxis]\n",
        "x1 = tf.linspace(-10.0, 10.0, 200+1)[tf.newaxis, :]\n",
        "\n",
        "xs = tf.stack(tf.meshgrid(x0, x1), axis=-1)\n",
        "xs_flat = tf.reshape(xs, [-1,2])\n",
        "\n",
        "batch_shape = xs.shape[:-1]\n",
        "print(batch_shape)\n",
        "print(xs.shape)\n",
        "print(xs_flat.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JobKcjqykvXc"
      },
      "source": [
        "Create a multivariate normal distribution over `x` "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VBuCiyDRHqPJ",
        "colab": {}
      },
      "source": [
        "from tensorflow_probability import distributions as tfd\n",
        "\n",
        "dist = tfd.MultivariateNormalFullCovariance(\n",
        "    loc=[2,1], covariance_matrix=[[9,-2],[-2,4]])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PUG-RyY6k9QO"
      },
      "source": [
        "Calculate the probability on the `x` grid, and take a few samples"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qpOSsegak6Q-",
        "colab": {}
      },
      "source": [
        "px = dist.prob(xs_flat)\n",
        "sample_xs = dist.sample(sample_shape=50)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lFojvVeQlIGl"
      },
      "source": [
        "Plot the density and samples"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "xIw5XwDb_M4D",
        "colab": {}
      },
      "source": [
        "def plot_sheet(coords_flat, value_flat, batch_shape, sample = None, **kwargs):\n",
        "  c0 = tf.reshape(coords_flat[..., 0], batch_shape)\n",
        "  c1 = tf.reshape(coords_flat[..., 1], batch_shape)\n",
        "\n",
        "  value = tf.reshape(value_flat, batch_shape)\n",
        "\n",
        "  if sample is not None:\n",
        "    plt.scatter(sample[:,0], sample[:,1], c='w', marker='.', zorder=1)\n",
        "  plt.pcolormesh(c0, c1, value, zorder=-1, shading='gouraud', **kwargs)\n",
        "\n",
        "  plt.axis('off')\n",
        "  plt.gca().set_aspect('equal')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IfSihi4sNny6",
        "colab": {}
      },
      "source": [
        "plot_sheet(xs_flat, px, batch_shape, sample=sample_xs)\n",
        "plt.title('A MultiVariate probability density')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Density $1/x^2$')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_GCJJs-OlRpk"
      },
      "source": [
        "## Transform the coordinates\n",
        "\n",
        "Transform the `x` coordinates to `u`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "lmtmThLpUlbA",
        "colab": {}
      },
      "source": [
        "def transform(xs):\n",
        "  x0 = xs[..., 0] \n",
        "  x1 = xs[..., 1] \n",
        "\n",
        "  u0 = x0 + tf.sin(x1)*0.9\n",
        "  u1 = x1 + tf.cos(x0)*0.9\n",
        "\n",
        "  us = tf.stack([u0,u1], axis=-1)\n",
        "\n",
        "  return us"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "4p7Du5CYA8vO",
        "colab": {}
      },
      "source": [
        "us_flat = transform(xs_flat)\n",
        "sample_us = transform(sample_xs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qLUYGl0NleDb"
      },
      "source": [
        "Plot with the transformed coordinates, the density map is wrong."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rs0uOrnhNsvg",
        "colab": {}
      },
      "source": [
        "plot_sheet(us_flat, px, batch_shape, sample=sample_us)\n",
        "plt.title('This is wrong.')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Wrong units: $1/x^2$') "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MtyjUMlelo-8"
      },
      "source": [
        "## Apply the jacobian determinant\n",
        "\n",
        "Calculate the Jacobian determinant at each point, to scale the density."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VrqEw9wsF229",
        "colab": {}
      },
      "source": [
        "with tf.GradientTape() as tape:\n",
        "  tape.watch(xs_flat)\n",
        "  us_flat = transform(xs_flat)\n",
        "\n",
        "js = tape.batch_jacobian(us_flat, xs_flat)\n",
        "js_scale = tf.linalg.det(js)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vr5n_EqZmL0-"
      },
      "source": [
        "The [Jacobian determinant](https://www.youtube.com/watch?v=p46QWyHQE6M) tells you locally how much each differential area has expanded or contracted, and how much the density changed."
      ]
    },
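    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The relation being used can be written out explicitly. The block below is a sketch added for reference, not part of the original walkthrough: the symbols $p_x$, $p_u$ and $J$ are notation introduced here, with $J$ the matrix of partial derivatives $\\partial u_i / \\partial x_j$ of `transform`, which is assumed to be one-to-one and differentiable.\n",
        "\n",
        "$$\n",
        "p_u(u) = \\frac{p_x(x)}{\\left|\\det J(x)\\right|}, \\qquad u = \\mathrm{transform}(x)\n",
        "$$\n",
        "\n",
        "For the `transform` defined above, the determinant can also be worked out by hand:\n",
        "\n",
        "$$\n",
        "J = \\begin{pmatrix} 1 & 0.9\\cos x_1 \\\\\\\\ -0.9\\sin x_0 & 1 \\end{pmatrix}, \\qquad \\det J = 1 + 0.81\\,\\sin x_0 \\cos x_1\n",
        "$$\n",
        "\n",
        "Because $|\\sin x_0 \\cos x_1| \\le 1$, this determinant stays between $0.19$ and $1.81$. It is never zero or negative, which is why the absolute value can be dropped for this particular transform."
      ]
    },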
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "agzbkcXaNy_B",
        "colab": {}
      },
      "source": [
        "plot_sheet(us_flat, 1/js_scale, batch_shape)\n",
        "plt.title('Jacobian density change')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Area scale factor $u^2/x^2$') "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lzYZtzLQnFtN"
      },
      "source": [
        "Divide the density in `x` by the Jacobian determinant to get the density in `u`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HSOvbSC_OCU8",
        "colab": {}
      },
      "source": [
        "pu = px/js_scale\n",
        "\n",
        "plot_sheet(us_flat, pu, batch_shape, sample=sample_us)\n",
        "plt.title('Transformed density')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Density $1/u^2$') "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dEqqfN36nW7d"
      },
      "source": [
        "Here is everything together:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "KcsxU3k8I5Cc",
        "colab": {}
      },
      "source": [
        "vmax = pu.numpy().max()\n",
        "plt.subplot(2,2,1)\n",
        "plot_sheet(xs_flat, px, batch_shape, vmax=vmax)\n",
        "plt.title('A MultiVariate probability density')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Density $1/x^2$')\n",
        "\n",
        "plt.subplot(2,2,2)\n",
        "plot_sheet(us_flat, px, batch_shape, vmax=vmax)\n",
        "plt.title('This is wrong')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Wrong units: $1/x^2$') \n",
        "\n",
        "plt.subplot(2,2,3)\n",
        "plot_sheet(us_flat, 1/js_scale, batch_shape)\n",
        "plt.title('Jacobian density change')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Area scale factor $x^2/u^2$') \n",
        "\n",
        "plt.subplot(2,2,4)\n",
        "plot_sheet(us_flat, pu, batch_shape, vmax=vmax)\n",
        "plt.title('Transformed density')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Density $1/u^2$') "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "C8jBDf-rnc17"
      },
      "source": [
        "## Re-grid the new density\n",
        "\n",
        "Now create a new grid over `u`, and re-evaluate the density, to make integration easier.   "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "JfqVcJppJBzv",
        "colab": {}
      },
      "source": [
        "u0_mesh = tf.linspace(-12.0, 12, 240+1)\n",
        "u1_mesh = tf.linspace(-10.0, 10.0, 200+1)\n",
        "\n",
        "us_mesh = tf.stack(tf.meshgrid(u0_mesh, u1_mesh), axis=-1)\n",
        "new_batch_shape = us_mesh.shape[:-1]\n",
        "us_mesh_flat = tf.reshape(us_mesh, [-1,2])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "PtFt1qVR5po-",
        "colab": {}
      },
      "source": [
        "import scipy.interpolate\n",
        "pu_mesh_flat = scipy.interpolate.griddata(us_flat, pu, us_mesh_flat, fill_value=0.0)\n",
        "pu_mesh = tf.reshape(pu_mesh_flat, new_batch_shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "1DquYoub8a-c",
        "colab": {}
      },
      "source": [
        "plt.subplot(2,1,1)\n",
        "plot_sheet(us_flat, pu, batch_shape, vmax=vmax)\n",
        "plt.title('Transformed density')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Density $1/u^2$') \n",
        "\n",
        "plt.subplot(2,1,2)\n",
        "plot_sheet(us_mesh_flat, pu_mesh, new_batch_shape)\n",
        "plt.title('Transformed density, re-meshed')\n",
        "cbar = plt.colorbar()\n",
        "cbar.set_label('Density $1/u^2$') "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-zfXEwt_sO95"
      },
      "source": [
        "## Calculate the marginal distributions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vRSwT2lpn3_e"
      },
      "source": [
        "With the density in u reevaluated on a nice grid, it's possible to integrate and get nice results.\n",
        "\n",
        "Integrate to get the two marginals. "
      ]
    },
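    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "For reference, the two marginals are the integrals of the joint density over the other coordinate. The notation $p(u_0, u_1)$ for the re-gridded joint density is introduced here for illustration.\n",
        "\n",
        "$$\n",
        "p(u_0) = \\int p(u_0, u_1)\\,du_1, \\qquad p(u_1) = \\int p(u_0, u_1)\\,du_0\n",
        "$$\n",
        "\n",
        "On the grid, each integral is approximated with the trapezoidal rule along one axis of `pu_mesh`, whose rows index `u1` and whose columns index `u0`."
      ]
    },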
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "pfOktyEzFfqW",
        "colab": {}
      },
      "source": [
        "import scipy.integrate\n",
        "pu1 = scipy.integrate.trapz(tf.reshape(pu_mesh, new_batch_shape), u0_mesh)\n",
        "pu0 = scipy.integrate.trapz(tf.reshape(pu_mesh, new_batch_shape).numpy().T, u1_mesh)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Rsip3ylRocL2"
      },
      "source": [
        "Integrate the marginals to sanity-check that the total probability mass is `~=1.0`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WhvcIrksoU1J",
        "colab": {}
      },
      "source": [
        "print(scipy.integrate.trapz(pu1, u1_mesh))\n",
        "print(scipy.integrate.trapz(pu0, u0_mesh))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wfheqrtmo1Bh"
      },
      "source": [
        "Plot the marginals, they're still surprisingly Gaussian."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_rtM-pLnHYzd",
        "colab": {}
      },
      "source": [
        "def plot_extras(us, sheet, u0=None, v0=None, u1=None, v1=None):\n",
        "  axes = [[None, None],\n",
        "          [None, None]]\n",
        "\n",
        "  batch_shape = sheet.shape\n",
        "  fig = plt.figure()\n",
        "\n",
        "  ax_joint = fig.add_axes([0.1, 0.1, 0.6, 0.6])\n",
        "  axes[1][0]=ax_joint\n",
        "  plot_sheet(us, sheet, batch_shape)\n",
        "  plt.axis('on')\n",
        "  plt.xlim([-12,12])\n",
        "  plt.ylim([-10,10])\n",
        "  position = ax_joint.get_position()\n",
        "\n",
        "  if u0 is not None:\n",
        "    u0_ax = fig.add_axes([position.x0, position.y1+0.005, position.width, 0.15],\n",
        "                        sharex=ax_joint)\n",
        "    axes[0][0]=u0_ax\n",
        "    plt.plot(u0, v0)\n",
        "    plt.axis('off')\n",
        "\n",
        "  if u1 is not None:\n",
        "    u1_ax = fig.add_axes([position.x1+0.005, position.y0, 0.15, position.height],\n",
        "                sharey=ax_joint)\n",
        "    axes[1][1]=u1_ax\n",
        "    plt.plot(v1, u1)\n",
        "    plt.axis('off')\n",
        "\n",
        "  return axes"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rRPvZAzzIHQw",
        "colab": {}
      },
      "source": [
        "axes = plot_extras(us_mesh, pu_mesh, u0_mesh, pu0, u1_mesh, pu1)\n",
        "_ = axes[0][0].set_title('Joint and marginals')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SkSP5k5BpD_j"
      },
      "source": [
        "## Calculate the conditional distributions\n",
        "\n",
        "Divide by the marginals to calculate the conditional distribution. \n",
        "\n",
        "Here is `p(u1|u0)` and a slice of it."
      ]
    },
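    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Written out, using the same illustrative notation $p(u_0, u_1)$ for the joint density:\n",
        "\n",
        "$$\n",
        "p(u_1 \\mid u_0) = \\frac{p(u_0, u_1)}{p(u_0)}, \\qquad p(u_0 \\mid u_1) = \\frac{p(u_0, u_1)}{p(u_1)}\n",
        "$$\n",
        "\n",
        "In the cells below, broadcasting does the division: `pu0[None, :]` divides each column of `pu_mesh` (fixed `u0`) by that column's marginal, and `pu1[:, None]` does the same per row. The one-dimensional slice drawn beside each plot is interpolated from the joint density, so it is proportional to the conditional at that value rather than normalized."
      ]
    },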
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HsytxnYhVwMs",
        "colab": {}
      },
      "source": [
        "at_u0 = 1.0\n",
        "u1_slice = scipy.interpolate.griddata(\n",
        "    us_mesh_flat, pu_mesh_flat, (at_u0*np.ones_like(u1_mesh), u1_mesh), fill_value=0.0)\n",
        "\n",
        "\n",
        "axes = plot_extras(us_mesh, pu_mesh/pu0[None, :], u0_mesh, pu0, u1_mesh, u1_slice)\n",
        "plt.sca(axes[1][0])\n",
        "plt.plot([at_u0,at_u0], plt.ylim(), color='w')\n",
        "axes[1][0].set_xlabel('p(u1|u0)')\n",
        "axes[0][0].set_title('p(u0)')\n",
        "_ = axes[1][1].set_title(f'p(u1|u0={at_u0})')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gE84kHzXpusM"
      },
      "source": [
        "Here is `p(u0|u1)` and a slice of it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Sng3Q87nV6Lh",
        "colab": {}
      },
      "source": [
        "at_u1 = 2.0\n",
        "u0_slice = scipy.interpolate.griddata(\n",
        "    us_mesh_flat, pu_mesh_flat, (u0_mesh, at_u1*np.ones_like(u0_mesh)), fill_value=0.0)\n",
        "\n",
        "axes = plot_extras(us_mesh, pu_mesh/pu1[:,None], u0_mesh, u0_slice, u1_mesh, pu1)\n",
        "plt.sca(axes[1][0])\n",
        "plt.plot(plt.xlim(), [at_u1,at_u1], color='w')\n",
        "axes[1][0].set_xlabel('p(u0|u1)')\n",
        "axes[1][1].set_title('p(u1)')\n",
        "_ = axes[0][0].set_title(f'p(u0|u1={at_u1})')"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
